clc
clear        % plain 'clear' suffices; 'clear all' also purges loaded functions and slows the next run
close all

% data locations
usgs_sf_read_path = '/gpfsm/dnb05/projects/p15/iau/merra_land/DATA/USGS_StreamFlow/';
cn51_read_path = '/discover/nobackup/projects/geoscm/fzeng/Catchment-CN5.1/GEOSldas_CN51/output/SMAP_EASEv2_M36/cat/ens0000/';

% analysis period: start_year/start_month through stop_year/stop_month
start_year = 1980;
stop_year = 1988;
start_month = 1;
stop_month = 4;

% leap years within the analysis period (Gregorian rule: divisible by 4,
% except centuries not divisible by 400), and days per month for
% non-leap ('nly') and leap ('ly') years
all_years = start_year:stop_year;
leapYears = all_years(mod(all_years,4)==0 & ...
                      (mod(all_years,100)~=0 | mod(all_years,400)==0));
monthLength_nly = [31 28 31 30 31 30 31 31 30 31 30 31];
monthLength_ly = [31 29 31 30 31 30 31 31 30 31 30 31];

% read station information
%
% Reads per-station metadata from Riverflow_Station_Information_NA-M36.nc
% and builds station_ID_new, a 1 x n_stations cell array of station-ID
% strings used to match the USGS files.

station_info_file = [usgs_sf_read_path 'Riverflow_Station_Information_NA-M36.nc'];

station_lat = double(ncread(station_info_file,'sta_lat'));
station_lon = double(ncread(station_info_file,'sta_lon'));
comp_basin_area = double(ncread(station_info_file,'CompBasinArea'));  % km^2
comp_basin_area = comp_basin_area*1000000;         % m^2
station_name = double(ncread(station_info_file,'Station_Name'));
station_ID = double(ncread(station_info_file,'Station_ID'));
n_catch_basin = double(ncread(station_info_file,'N_catch_basin'));
global_id = double(ncread(station_info_file,'GlobalID'));
local_id = double(ncread(station_info_file,'LocalID'));
n_cells_basin = double(ncread(station_info_file,'N_cells_basin'));
smap_id = double(ncread(station_info_file,'SMAPID'));
smap_frac = double(ncread(station_info_file,'SMAPFrac'));
smap_i = double(ncread(station_info_file,'I_SMAP'));
smap_j = double(ncread(station_info_file,'J_SMAP'));

% Station_ID holds character codes, one station per column. Convert the
% first 8 characters of each column back to text (8 characters to match
% the '%8c' site numbers read from the USGS files). char() handles any
% character, unlike per-digit num2str(code-48), which only works for '0'-'9'.
n_stations = numel(station_lat);
station_ID_new = cellstr(char(station_ID(1:8,1:n_stations).')).';

% read model data
%
% For every day from start_year/start_month through stop_year/stop_month,
% read RUNOFF and BASEFLOW from the CN5.1 tavg24_1d_lnd_Nt file, aggregate
% their sum over the tiles listed for each basin, and form trailing 30-day
% and 15-day means.
%
% Script variables produced:
%   cn51_datetime              1 x n_days datetime, one entry per model day
%   cn51_streamflow_30d_mat    n_basins x n_days trailing 30-day mean, ft^3 s^-1
%   cn51_streamflow_15d_mat    n_basins x n_days trailing 15-day mean, ft^3 s^-1
%   cn51_lat_vec, cn51_lon_vec 'lat'/'lon' read from the first model file
%   mat_index                  number of model days processed

% one entry per day of the analysis period (eomday handles leap years)
cn51_datetime = datetime(start_year,start_month,1):caldays(1): ...
                datetime(stop_year,stop_month,eomday(stop_year,stop_month));
n_days = numel(cn51_datetime);

% Per-basin tile lists (loop-invariant, so built once): column i of
% smap_id/smap_frac describes basin i; negative IDs are dropped as unused.
n_basins = size(smap_id,2);
basin_tiles = cell(1,n_basins);
basin_fracs = cell(1,n_basins);
for i = 1:n_basins
    keep = smap_id(:,i) >= 0;
    basin_tiles{i} = smap_id(keep,i);
    basin_fracs{i} = smap_frac(keep,i);
end

% daily basin-total streamflow, kg s^-1
daily_sf_mat = NaN(n_basins,n_days);

for k = 1:n_days

    y = cn51_datetime(k).Year;
    m = cn51_datetime(k).Month;
    d = cn51_datetime(k).Day;

    cn51_read_file = sprintf('%sY%04d/M%02d/GEOSldas_CN51.tavg24_1d_lnd_Nt.%04d%02d%02d_1200z.nc4', ...
                             cn51_read_path,y,m,y,m,d);

    if k == 1
        cn51_lat_vec = double(ncread(cn51_read_file,'lat'));
        cn51_lon_vec = double(ncread(cn51_read_file,'lon'));
    end

    % streamflow in each tile (runoff + baseflow), kg m^-2 s^-1
    tile_sf = double(ncread(cn51_read_file,'RUNOFF')) + ...
              double(ncread(cn51_read_file,'BASEFLOW'));

    for i = 1:n_basins
        q = tile_sf(basin_tiles{i});
        w = basin_fracs{i};
        q = q(:);
        w = w(:);
        ok = ~isnan(q) & ~isnan(w);
        % Fraction-weighted basin mean (kg m^-2 s^-1): dividing by the sum
        % of the weights, not the tile count, keeps partially covered tiles
        % from biasing the mean low. NOTE(review): assumes equal-area tiles
        % (EASEv2 M36) and SMAPFrac as the basin share of each tile -- confirm.
        % Multiplying by the basin area gives kg s^-1.
        daily_sf_mat(i,k) = sum(q(ok).*w(ok))/sum(w(ok))*comp_basin_area(i);
    end
end

% Trailing windows ending on each day (30 or 15 days including the current
% day). At the start of the record the window shrinks to the days
% available; NaNs are ignored.
cn51_streamflow_30d_mat = movmean(daily_sf_mat,[29 0],2,'omitnan');
cn51_streamflow_15d_mat = movmean(daily_sf_mat,[14 0],2,'omitnan');
mat_index = n_days;

% convert model streamflow from kg s^-1 to ft^3 s^-1 (water):
% 1 ft^3 = 0.0283168 m^3, i.e. about 28.32 kg of water at 1000 kg m^-3
cn51_streamflow_30d_mat = cn51_streamflow_30d_mat./28.32;
cn51_streamflow_15d_mat = cn51_streamflow_15d_mat./28.32;

% read station data
%
% For each USGS daily streamflow file: identify the station, align its
% observations with the model days by calendar date, and score the trailing
% 30-day and 15-day model means against them. Only days where both
% observation and model are available are used. Metrics are indexed by
% file number i. They stay NaN when a file cannot be read, its station is
% not found, or there are no paired days.
%   r_*         Pearson correlation
%   abs_bias_*  mean absolute difference |obs - model|
%   rmse_*      root-mean-square difference
%   r2_*        1 - sum((obs-model)^2) / sum((obs-mean(obs))^2)

usgs_data_dir = [usgs_sf_read_path '2022-11-18/'];
usgs_sf_files = dir(usgs_data_dir);
% skip '.', '..' and any subfolders (dir does not guarantee they come first)
usgs_sf_files = usgs_sf_files(~[usgs_sf_files.isdir]);
n_files = numel(usgs_sf_files);

r_30d = NaN(1,n_files);
r_15d = NaN(1,n_files);
abs_bias_30d = NaN(1,n_files);
abs_bias_15d = NaN(1,n_files);
rmse_30d = NaN(1,n_files);
rmse_15d = NaN(1,n_files);
r2_30d = NaN(1,n_files);
r2_15d = NaN(1,n_files);

for i = 1:n_files

    read_file = [usgs_data_dir usgs_sf_files(i).name];
    fid = fopen(read_file);
    if fid < 0
        warning('Cannot open %s; skipping.',read_file);
        continue
    end
    data = textscan(fid,'%4c %8c %s %f %s','HeaderLines',31);
    fclose(fid);

    usgs_station_id = data{2};
    if isempty(usgs_station_id)
        warning('No records read from %s; skipping.',read_file);
        continue
    end
    usgs_datetime = datetime(data{3},'InputFormat','yyyy-MM-dd');
    usgs_sf_data = data{4};

    % model row for this station (last match, as in a linear search)
    cn51_mat_index = find(strcmp(station_ID_new,usgs_station_id(1,:)),1,'last');
    if isempty(cn51_mat_index)
        warning('Station %s not found in station information; skipping.',usgs_station_id(1,:));
        continue
    end

    % Observations on the model days (NaN where the USGS record has no
    % entry); date matching tolerates gaps and partial coverage.
    [has_obs,obs_loc] = ismember(cn51_datetime,usgs_datetime);
    usgs_sf_tmp_mat = NaN(size(cn51_datetime));
    usgs_sf_tmp_mat(has_obs) = usgs_sf_data(obs_loc(has_obs));

    cn51_sf_30d_tmp_mat = cn51_streamflow_30d_mat(cn51_mat_index,:);
    cn51_sf_15d_tmp_mat = cn51_streamflow_15d_mat(cn51_mat_index,:);

    mean_usgs_sf = mean(usgs_sf_tmp_mat,'omitnan');
    mean_cn51_30d = mean(cn51_sf_30d_tmp_mat,'omitnan');
    mean_cn51_15d = mean(cn51_sf_15d_tmp_mat,'omitnan');

    % compute metrics based on daily data (paired, non-NaN days only)

    paired = ~isnan(usgs_sf_tmp_mat) & ~isnan(cn51_sf_30d_tmp_mat);
    obs_30d = usgs_sf_tmp_mat(paired);
    diff_30d = obs_30d - cn51_sf_30d_tmp_mat(paired);
    if nnz(paired) > 1
        r_tmp = corrcoef(obs_30d,cn51_sf_30d_tmp_mat(paired));
        r_30d(i) = r_tmp(1,2);
    end
    abs_bias_30d(i) = mean(abs(diff_30d));
    rmse_30d(i) = sqrt(mean(diff_30d.^2));
    r2_30d(i) = 1 - sum(diff_30d.^2)/sum((obs_30d - mean(obs_30d)).^2);

    paired = ~isnan(usgs_sf_tmp_mat) & ~isnan(cn51_sf_15d_tmp_mat);
    obs_15d = usgs_sf_tmp_mat(paired);
    diff_15d = obs_15d - cn51_sf_15d_tmp_mat(paired);
    if nnz(paired) > 1
        r_tmp = corrcoef(obs_15d,cn51_sf_15d_tmp_mat(paired));
        r_15d(i) = r_tmp(1,2);
    end
    abs_bias_15d(i) = mean(abs(diff_15d));
    rmse_15d(i) = sqrt(mean(diff_15d.^2));
    r2_15d(i) = 1 - sum(diff_15d.^2)/sum((obs_15d - mean(obs_15d)).^2);

end


